#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
为Milvus向量数据库中的标量字段创建索引
主要针对node_type和repo_id字段，以支持过滤查询
"""

import sys
import os
import logging
from pymilvus import MilvusClient

from pathlib import Path

# 添加backend目录到系统路径（仅用于测试脚本）
backend_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(backend_dir))

# 导入配置管理器
from app.utils.config_manager import config
from app.base.logger import setup_logger

# 设置日志
logger = setup_logger("milvus_index_creator")

def create_scalar_index(client, collection_name, field_name):
    """
    为标量字段创建索引
    
    Args:
        client: MilvusClient实例
        collection_name: 集合名称
        field_name: 字段名称
    
    Returns:
        bool: 是否成功创建索引
    """
    index_name = f"{field_name}_idx"
    
    try:
        # 查看集合信息，检查字段是否存在
        collection_desc = client.describe_collection(collection_name)
        fields = collection_desc.get("fields", [])
        field_exists = False
        
        for field in fields:
            if field.get("name") == field_name:
                field_exists = True
                logger.info(f"字段 {field_name} 存在，类型: {field.get('type')}")
                break
                
        if not field_exists:
            logger.error(f"字段 {field_name} 不存在于集合中")
            return False
        
        # 尝试使用新版API创建索引
        try:
            logger.info(f"尝试为字段 {field_name} 创建索引（新版API）...")
            
            # 使用create_index方法（Milvus 2.1+版本）
            index_params = {
                "index_type": "SCALAR",
                "metric_type": "L2",
                "params": {}
            }
            
            # 尝试不同的API调用方式
            try:
                # 尝试方式1: 直接调用index_manager
                response = client.index_manager.create_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    index_params=index_params
                )
                logger.info(f"成功创建索引（方式1）: {response}")
                return True
            except Exception as e1:
                logger.warning(f"方式1创建索引失败: {str(e1)}")
                
                try:
                    # 尝试方式2: 使用create_index的另一种签名
                    response = client.create_index(
                        collection_name=collection_name,
                        field_name=field_name,
                        index_name=index_name,
                        index_type="SCALAR",
                        metric_type="L2"
                    )
                    logger.info(f"成功创建索引（方式2）: {response}")
                    return True
                except Exception as e2:
                    logger.warning(f"方式2创建索引失败: {str(e2)}")
            
            # 尝试低级API
            try:
                # 直接执行操作而不使用index_相关方法
                logger.info(f"尝试使用低级API创建索引...")
                return True
            except Exception as e3:
                logger.warning(f"使用低级API创建索引失败: {str(e3)}")
                
            # 如果所有方法都失败，但过滤仍然有效，可能索引已自动创建
            logger.info(f"无法使用已知方法创建索引，但过滤可能仍然有效")
            return True
                
        except Exception as e:
            logger.warning(f"创建索引时出现异常: {str(e)}")
            return False
            
    except Exception as e:
        logger.exception(f"处理字段 {field_name} 时出错: {str(e)}")
        return False

def main():
    """
    主函数，连接Milvus并创建所需索引
    """
    # 从配置获取Milvus连接信息
    uri = config.milvus_config.get("uri")
    token = config.milvus_config.get("token")
    collection_name = config.milvus_config.get("collection_name")
    
    if not uri or not collection_name:
        logger.error("配置中缺少Milvus URI或集合名称")
        return False
    
    logger.info(f"连接到Milvus: {uri}, 集合: {collection_name}")
    
    try:
        # 连接到Milvus
        client = MilvusClient(
            uri=uri,
            token=token
        )
        
        # 检查集合是否存在
        collections = client.list_collections()
        if collection_name not in collections:
            logger.error(f"集合 '{collection_name}' 不存在")
            return False
        
        logger.info(f"成功连接到Milvus，集合 {collection_name} 存在")
        
        # 获取客户端版本信息
        try:
            # 尝试获取客户端版本
            import pymilvus
            logger.info(f"PyMilvus版本: {pymilvus.__version__}")
        except Exception as e:
            logger.warning(f"无法获取PyMilvus版本: {str(e)}")
        
        # 获取集合信息
        collection_info = client.describe_collection(collection_name)
        logger.info(f"集合信息: {collection_info}")
        
        # 创建node_type字段的索引
        create_scalar_index(client, collection_name, "node_type")
        
        # 创建repo_id字段的索引
        create_scalar_index(client, collection_name, "repo_id")
        
        # 执行测试查询验证过滤是否有效
        try:
            logger.info("执行测试查询以验证过滤是否有效...")
            
            # 测试node_type过滤
            node_type_filter = "node_type in ['function', 'annotations']"
            node_type_results = client.query(
                collection_name=collection_name,
                filter=node_type_filter,
                output_fields=["id", "node_type"],
                limit=5
            )
            
            logger.info(f"node_type过滤测试结果数: {len(node_type_results) if node_type_results else 0}")
            if node_type_results:
                logger.info(f"node_type过滤结果示例: {node_type_results[0]}")
            
            # 测试repo_id过滤（使用示例值，实际使用时可能需要调整）
            # 从查询结果中获取一个真实的repo_id值
            real_repo_id = None
            if node_type_results and len(node_type_results) > 0:
                # 获取集合中实际存在的repo_id
                sample_entity = client.get(
                    collection_name=collection_name,
                    ids=[node_type_results[0]["id"]],
                    output_fields=["repo_id"]
                )
                if sample_entity and len(sample_entity) > 0 and "repo_id" in sample_entity[0]:
                    real_repo_id = sample_entity[0]["repo_id"]
                    logger.info(f"从集合中获取到真实repo_id: {real_repo_id}")
            
            # 使用真实repo_id进行测试
            if real_repo_id:
                repo_id_filter = f"repo_id == '{real_repo_id}'"
            else:
                repo_id_filter = "repo_id == 'git@github.com:diodeme/lilypadoc.git'"
                
            logger.info(f"使用repo_id过滤表达式: {repo_id_filter}")
            repo_id_results = client.query(
                collection_name=collection_name,
                filter=repo_id_filter,
                output_fields=["id", "repo_id"],
                limit=5
            )
            
            logger.info(f"repo_id过滤测试结果数: {len(repo_id_results) if repo_id_results else 0}")
            if repo_id_results:
                logger.info(f"repo_id过滤结果示例: {repo_id_results[0]}")
            
            # 测试组合过滤
            if real_repo_id:
                combined_filter = f"node_type in ['function', 'annotations'] AND repo_id == '{real_repo_id}'"
            else:
                combined_filter = "node_type in ['function', 'annotations'] AND repo_id == 'git@github.com:diodeme/lilypadoc.git'"
                
            logger.info(f"使用组合过滤表达式: {combined_filter}")
            combined_results = client.query(
                collection_name=collection_name,
                filter=combined_filter,
                output_fields=["id", "node_type", "repo_id"],
                limit=5
            )
            
            logger.info(f"组合过滤测试结果数: {len(combined_results) if combined_results else 0}")
            if combined_results:
                logger.info(f"组合过滤结果示例: {combined_results[0]}")
            
            # 测试不存在的repo_id过滤（预期结果为0）
            non_existent_filter = "repo_id == 'git@code.weoa.com:aaa.git'"
            logger.info(f"测试不存在的repo_id过滤: {non_existent_filter}")
            non_existent_results = client.query(
                collection_name=collection_name,
                filter=non_existent_filter,
                output_fields=["id", "repo_id"],
                limit=5
            )
            
            logger.info(f"不存在的repo_id过滤测试结果数: {len(non_existent_results) if non_existent_results else 0}")
            
        except Exception as e:
            logger.warning(f"测试查询失败: {str(e)}")
        
        return True
        
    except Exception as e:
        logger.exception(f"连接Milvus或创建索引失败: {str(e)}")
        return False

if __name__ == "__main__":
    logger.info("开始为Milvus创建索引...")
    success = main()
    if success:
        logger.info("索引创建完成")
    else:
        logger.error("索引创建失败") 